package com.sxf.mybatis.interceptor;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Properties;
import java.util.StringTokenizer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.sxf.mybatis.dialect.Dialect;
import com.sxf.mybatis.dialect.MySQLDialect;
import com.sxf.mybatis.dialect.OracleDialect;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.session.Configuration;
import org.apache.ibatis.session.RowBounds;

import com.sxf.mybatis.help.ExRowBounds;
import com.sxf.mybatis.help.QueryType;

/**
 * MyBatis plugin that intercepts the {@code prepare} method of
 * {@link RoutingStatementHandler} and, for queries issued with an
 * {@link ExRowBounds}, rewrites the statement before it is prepared:<br>
 * 1. pagination is pushed down to the database by rewriting the SQL through the
 * configured {@link Dialect} (requires hibernate support);<br>
 * 2. row counting is supported (mysql); when only the count is requested the
 * paging condition is disabled.
 * <p>
 * The dialect name is read from the {@code dialect} entry of the MyBatis
 * configuration variables. Statements are left untouched when that entry is
 * missing or when the query does not carry an {@link ExRowBounds}.
 *
 * @author SXF
 */
@Intercepts({ @Signature(type = StatementHandler.class, method = "prepare", args = { Connection.class }) })
public class ProcessStatementHandlerInterceptor implements Interceptor {

	// NOTE(review): the logger is bound to java.sql.PreparedStatement rather than
	// to this class - presumably so that it follows the SQL logging category;
	// confirm before changing it.
	private static final Log log = LogFactory.getLog(PreparedStatement.class);

	/** Matches one JDBC positional placeholder. */
	private static final Pattern PLACEHOLDER = Pattern.compile("[?]");

	/**
	 * Dialect resolved lazily on first use and shared by every instance of this
	 * plugin. Volatile because {@link #intercept(Invocation)} may run on several
	 * threads at once.
	 */
	private static volatile Dialect dialect = null;

	/**
	 * Rewrites the intercepted statement when it qualifies (see
	 * {@link #rewriteStatement(RoutingStatementHandler)}) and then lets the
	 * original {@code prepare} call proceed.
	 *
	 * @param invocation the intercepted {@code StatementHandler.prepare} call
	 * @return whatever the intercepted method returns
	 * @throws Throwable any failure of the count query or of the intercepted call
	 */
	@Override
	public Object intercept(Invocation invocation) throws Throwable {
		Object target = invocation.getTarget();
		if (target instanceof RoutingStatementHandler) {
			rewriteStatement((RoutingStatementHandler) target);
		}
		return invocation.proceed();
	}

	@Override
	public Object plugin(Object target) {
		return Plugin.wrap(target, this);
	}

	/**
	 * No plugin properties are used; the dialect comes from the configuration
	 * variables instead.
	 */
	@Override
	public void setProperties(Properties properties) {
	}

	/**
	 * Applies the count and/or pagination rewrite selected by the query type of
	 * the statement's {@link ExRowBounds}. Does nothing when the statement has no
	 * {@link ExRowBounds} or when no dialect name is configured.
	 *
	 * @param statementHandler the handler whose statement is about to be prepared
	 * @throws SQLException if the separate count query fails
	 */
	private void rewriteStatement(RoutingStatementHandler statementHandler)
			throws SQLException {
		MetaObject metaStatementHandler = SystemMetaObject
				.forObject(statementHandler);
		RowBounds rowBounds = (RowBounds) metaStatementHandler
				.getValue("delegate.rowBounds");
		// instanceof is false for null and for the plain RowBounds.DEFAULT as well.
		if (!(rowBounds instanceof ExRowBounds)) {
			return;
		}
		ExRowBounds bounds = (ExRowBounds) rowBounds;
		Configuration configuration = (Configuration) metaStatementHandler
				.getValue("delegate.configuration");
		BoundSql boundSql = statementHandler.getBoundSql();

		Dialect activeDialect = dialect;
		if (activeDialect == null) {
			String databaseType = readDialectName(configuration);
			if (databaseType == null) {
				// No dialect configured: leave the statement untouched.
				return;
			}
			// May still be null for a database without a Dialect implementation;
			// setPagination copes with that.
			activeDialect = getDialect(databaseType);
			dialect = activeDialect;
		}

		QueryType queryType = bounds.getQueryType();
		if (queryType == null) {
			// An unset query type is treated as a plain object query.
			queryType = QueryType.OBJECT;
		}

		switch (queryType) {
		case NUMBER:
			// Replace the SQL statement with its counting form.
			setTotalCount(metaStatementHandler, boundSql);
			break;
		case OBJECT_NUMBER:
			// Fetch the total first: it must run against the original SQL and the
			// original parameter mappings, before pagination modifies them.
			getTotalCountForConnection(configuration, statementHandler, bounds,
					boundSql.getSql());
			setPagination(metaStatementHandler, rowBounds, configuration,
					boundSql, activeDialect);
			break;
		default:
			// Default query type (OBJECT): pagination only.
			setPagination(metaStatementHandler, rowBounds, configuration,
					boundSql, activeDialect);
			break;
		}
	}

	/**
	 * Reads the {@code dialect} entry of the configuration variables.
	 *
	 * @param configuration the MyBatis configuration of the statement
	 * @return the configured dialect name, or {@code null} when there is none
	 */
	private static String readDialectName(Configuration configuration) {
		Properties variables = configuration.getVariables();
		return variables == null ? null : variables.getProperty("dialect");
	}

	/**
	 * Replaces the statement with its counting form so that only the total is
	 * returned, and clears the row bounds.
	 *
	 * @param metaStatementHandler reflective view of the statement handler
	 * @param boundSql the bound SQL of the statement
	 */
	private void setTotalCount(MetaObject metaStatementHandler,
			BoundSql boundSql) {
		String countSql = getCountString(boundSql.getSql());
		metaStatementHandler.setValue("delegate.boundSql.sql", countSql);
		clearRowBounds(metaStatementHandler);
	}

	/**
	 * Resets offset and limit of the handler's row bounds to the MyBatis "no
	 * bounds" constants.
	 *
	 * @param metaStatementHandler reflective view of the statement handler
	 */
	private static void clearRowBounds(MetaObject metaStatementHandler) {
		metaStatementHandler.setValue("delegate.rowBounds.offset",
				RowBounds.NO_ROW_OFFSET);
		metaStatementHandler.setValue("delegate.rowBounds.limit",
				RowBounds.NO_ROW_LIMIT);
	}

	/**
	 * Runs the counting form of {@code originalSql} and stores both the count
	 * SQL and the resulting total on {@code bounds}.
	 * <p>
	 * NOTE(review): the query runs on a connection obtained from the
	 * environment's data source, not on the connection handed to
	 * {@code prepare}; it presumably does not take part in the current
	 * transaction - confirm.
	 *
	 * @param configuration the MyBatis configuration providing the data source
	 * @param statementHandler the handler whose parameter handler binds the
	 *            query parameters
	 * @param bounds receives the count SQL and the total
	 * @param originalSql the SQL to count the rows of
	 * @throws SQLException if the count query cannot be prepared or executed
	 */
	private void getTotalCountForConnection(Configuration configuration,
			StatementHandler statementHandler, ExRowBounds bounds,
			String originalSql) throws SQLException {
		String countSql = getCountString(originalSql);
		bounds.setCountSQL(countSql);
		Connection connection = configuration.getEnvironment().getDataSource()
				.getConnection();
		if (connection == null) {
			return;
		}

		long count = 0;
		// Nested try/finally so that every JDBC resource is released even when
		// preparing, binding or executing the count query fails.
		try {
			PreparedStatement countStatement = connection
					.prepareStatement(countSql);
			try {
				statementHandler.getParameterHandler().setParameters(
						countStatement);
				ResultSet rs = countStatement.executeQuery();
				try {
					if (rs.next()) {
						count = rs.getLong(1);
					}
				} finally {
					rs.close();
				}
			} finally {
				countStatement.close();
			}
		} finally {
			connection.close();
		}
		bounds.setCount(count);

		if (log.isDebugEnabled()) {
			log.debug("===>>>>生成的统计总数SQL:"
					+ removeBreakingWhitespace(countSql));
		}
	}

	/**
	 * Rewrites the statement into the dialect's paginated form (with hibernate
	 * style pagination the values of {@code limit ?,?} are bound as extra
	 * parameters) and clears the row bounds afterwards.
	 * <p>
	 * Nothing happens when the limit is not a real limit. When no dialect is
	 * available the statement and the row bounds are left untouched, so MyBatis'
	 * own RowBounds handling still applies.
	 *
	 * @param metaStatementHandler reflective view of the statement handler
	 * @param rowBounds offset and limit requested by the caller
	 * @param configuration the MyBatis configuration of the statement
	 * @param boundSql the bound SQL to rewrite
	 * @param dialect the dialect producing the paginated SQL, may be
	 *            {@code null}
	 */
	private void setPagination(MetaObject metaStatementHandler,
			RowBounds rowBounds, Configuration configuration,
			BoundSql boundSql, Dialect dialect) {
		int limit = rowBounds.getLimit();
		if (limit <= 0 || limit >= RowBounds.NO_ROW_LIMIT) {
			return;
		}
		if (dialect == null) {
			log.warn("No pagination dialect available for the configured database type; "
					+ "the SQL is not rewritten and the RowBounds are left unchanged");
			return;
		}

		String originalSql = boundSql.getSql();
		String pagedSql = dialect.getLimitStringByHibernate(originalSql,
				rowBounds.getOffset(), limit);
		metaStatementHandler.setValue("delegate.boundSql.sql", pagedSql);

		// Only the text added around the original SQL is inspected, so that the
		// query's own placeholders are not counted.
		int added = countPlaceholders(pagedSql.replace(originalSql, ""));
		if (added == 1 || added == 2) {
			// Work on a copy of the mappings and install it reflectively instead of
			// mutating the list returned by the BoundSql.
			List<ParameterMapping> mappings = new ArrayList<ParameterMapping>(
					boundSql.getParameterMappings());
			if (added == 2) {
				// Two placeholders are bound as offset first, then limit.
				mappings.add(new ParameterMapping.Builder(configuration,
						"offset", Integer.class).build());
				boundSql.setAdditionalParameter("offset",
						rowBounds.getOffset());
			}
			mappings.add(new ParameterMapping.Builder(configuration, "limit",
					Integer.class).build());
			boundSql.setAdditionalParameter("limit", limit);
			SystemMetaObject.forObject(boundSql).setValue("parameterMappings",
					mappings);
		}

		// The SQL now carries the paging condition itself.
		clearRowBounds(metaStatementHandler);
	}

	/**
	 * Counts the JDBC placeholders ({@code ?}) in {@code sql}.
	 *
	 * @param sql the SQL fragment to scan
	 * @return the number of {@code ?} characters found
	 */
	private static int countPlaceholders(String sql) {
		Matcher matcher = PLACEHOLDER.matcher(sql);
		int found = 0;
		while (matcher.find()) {
			found++;
		}
		return found;
	}

	/**
	 * Builds the SQL that counts the rows returned by {@code querySelect} by
	 * wrapping it in a {@code select count(1)} sub-select.
	 *
	 * @param querySelect the query whose rows are to be counted
	 * @return the counting SQL
	 */
	public static String getCountString(String querySelect) {
		return "select count(1) from ( " + querySelect + " ) t";
	}

	/**
	 * Collapses every run of whitespace (including line breaks) in
	 * {@code original} into a single space; each token is followed by one space.
	 *
	 * @param original the text to flatten
	 * @return the tokens of {@code original}, each followed by a single space
	 */
	private static String removeBreakingWhitespace(String original) {
		StringTokenizer tokens = new StringTokenizer(original);
		StringBuilder flattened = new StringBuilder();
		while (tokens.hasMoreTokens()) {
			flattened.append(tokens.nextToken()).append(" ");
		}
		return flattened.toString();
	}

	/**
	 * Maps a configured database type name (case-insensitive) to its dialect.
	 *
	 * @param databaseType the configured dialect name
	 * @return the matching dialect, or {@code null} when the name is unknown or
	 *         has no dialect implementation (the SqlServer names)
	 */
	private Dialect getDialect(String databaseType) {
		if ("MySQL".equalsIgnoreCase(databaseType)
				|| "MySQL5InnoDB".equalsIgnoreCase(databaseType)
				|| "MySQLMyISAM".equalsIgnoreCase(databaseType)) {
			return new MySQLDialect();
		}
		if ("ORACLE9I".equalsIgnoreCase(databaseType)
				|| "ORACLE10g".equalsIgnoreCase(databaseType)) {
			return new OracleDialect();
		}
		// "SqlServer", "SqlServer2005" and "SqlServer2008" are recognised names,
		// but no dialect is implemented for them.
		return null;
	}
}
